{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "21aab8ac",
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "The following env works\n",
    " - torch: 1.9.1\n",
    " - torchvision: 0.10.0+cu102\n",
    " - torch_sparse: 0.6.12\n",
    " \n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "3cc9dcc6-9d43-47e9-a4d3-0ba29beb425e",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import os\n",
    "PATH = '/home/chaoqiy2/github/PyHealth'\n",
    "os.chdir(PATH)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "2e5c4ec0-0887-48c7-8f97-1ccafb401f2d",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/chaoqiy2/miniconda3/envs/moltext/lib/python3.7/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "from pyhealth.sampler import NeighborSampler\n",
    "from pyhealth.models import Graph_TorchvisionModel\n",
    "from pyhealth.models import GCN\n",
    "from torchvision import transforms\n",
    "from pyhealth.datasets import COVID19CXRDataset"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a80901ce-1077-4272-a087-f728683bcdad",
   "metadata": {},
   "source": [
    "## Load Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "e1c54991-ed56-400f-a82e-3609ef0f953c",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from pyhealth.datasets import COVID19CXRDataset\n",
    "\n",
    "root = \"/srv/local/data/COVID-19_Radiography_Dataset\"\n",
    "base_dataset = COVID19CXRDataset(root)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "46475e1d-e31a-48f1-848d-457a9c2eee48",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "COVID19CXRClassification(task_name='COVID19CXRClassification', input_schema={'path': 'image'}, output_schema={'label': 'label'})"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "base_dataset.default_task"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "ad7bc53d-66db-499d-bddd-97c3e6184ad7",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Generating samples for COVID19CXRClassification: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 21165/21165 [00:00<00:00, 1282116.21it/s]\n"
     ]
    }
   ],
   "source": [
    "sample_dataset = base_dataset.set_task()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "b6caf164-9875-448a-938f-bae758b439d9",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from torchvision import transforms\n",
    "\n",
    "\n",
    "transform = transforms.Compose([\n",
    "    transforms.Lambda(lambda x: x if x.shape[0] == 3 else x.repeat(3, 1, 1)),\n",
    "    transforms.Resize((224, 224)),\n",
    "    transforms.Normalize(mean=[0.5862785803043838], std=[0.27950088968644304])\n",
    "])\n",
    "\n",
    "\n",
    "def encode(sample):\n",
    "    sample[\"path\"] = transform(sample[\"path\"])\n",
    "    return sample\n",
    "\n",
    "\n",
    "sample_dataset.set_transform(encode)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "6c46515c-9ff3-4ff5-b942-cf43666d3a2d",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from pyhealth.datasets import split_by_sample\n",
    "\n",
    "# Get Index of train, valid, test set\n",
    "train_index, val_index, test_index = split_by_sample(\n",
    "    dataset=sample_dataset,\n",
    "    ratios=[0.7, 0.1, 0.2],\n",
    "    get_index = True\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "a5c1e65c-ac81-421b-b265-7404fcab404c",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "| Uniform Initialization\n",
      "| Uniform Initialization\n"
     ]
    }
   ],
   "source": [
    "import ssl\n",
    "ssl._create_default_https_context = ssl._create_unverified_context\n",
    "\n",
    "model = Graph_TorchvisionModel(\n",
    "        dataset=sample_dataset,\n",
    "        feature_keys=[\"path\"],\n",
    "        label_key=\"label\",\n",
    "        mode=\"multiclass\",\n",
    "        model_name=\"resnet18\",\n",
    "        model_config={},\n",
    "        gnn_config={\"input_dim\": 256, \"hidden_dim\": 128},\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "c1640ebc-9fcb-49e1-9f2c-e7113ca85c62",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Build graph\n",
    "# Set random = True will build random graph data\n",
    "graph = model.build_graph(sample_dataset, random = True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "567e6091-87c7-4b37-a78d-402bc7c96939",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Define Sampler as Dataloader\n",
    "train_dataloader = NeighborSampler(sample_dataset, graph[\"edge_index\"], node_idx=train_index, sizes=[15, 10], batch_size=64, shuffle=True, num_workers=12)\n",
    "\n",
    "# We sample all edges connected to target node for validation and test (Sizes = [-1, -1])\n",
    "valid_dataloader = NeighborSampler(sample_dataset, graph[\"edge_index\"], node_idx=val_index, sizes=[-1, -1], batch_size=64, shuffle=False, num_workers=12)\n",
    "test_dataloader = NeighborSampler(sample_dataset, graph[\"edge_index\"], node_idx=test_index, sizes=[-1, -1], batch_size=64, shuffle=False, num_workers=12)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "c54b30ec-e7fc-402d-aba5-7c361e95cd60",
   "metadata": {
    "scrolled": true,
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Graph_TorchvisionModel(\n",
      "  (model): ResNet(\n",
      "    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n",
      "    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "    (relu): ReLU(inplace=True)\n",
      "    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n",
      "    (layer1): Sequential(\n",
      "      (0): BasicBlock(\n",
      "        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "        (relu): ReLU(inplace=True)\n",
      "        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      )\n",
      "      (1): BasicBlock(\n",
      "        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "        (relu): ReLU(inplace=True)\n",
      "        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      )\n",
      "    )\n",
      "    (layer2): Sequential(\n",
      "      (0): BasicBlock(\n",
      "        (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
      "        (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "        (relu): ReLU(inplace=True)\n",
      "        (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "        (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "        (downsample): Sequential(\n",
      "          (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
      "          (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "        )\n",
      "      )\n",
      "      (1): BasicBlock(\n",
      "        (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "        (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "        (relu): ReLU(inplace=True)\n",
      "        (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "        (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      )\n",
      "    )\n",
      "    (layer3): Sequential(\n",
      "      (0): BasicBlock(\n",
      "        (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
      "        (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "        (relu): ReLU(inplace=True)\n",
      "        (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "        (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "        (downsample): Sequential(\n",
      "          (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
      "          (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "        )\n",
      "      )\n",
      "      (1): BasicBlock(\n",
      "        (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "        (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "        (relu): ReLU(inplace=True)\n",
      "        (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "        (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      )\n",
      "    )\n",
      "    (layer4): Sequential(\n",
      "      (0): BasicBlock(\n",
      "        (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
      "        (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "        (relu): ReLU(inplace=True)\n",
      "        (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "        (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "        (downsample): Sequential(\n",
      "          (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
      "          (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "        )\n",
      "      )\n",
      "      (1): BasicBlock(\n",
      "        (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "        (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "        (relu): ReLU(inplace=True)\n",
      "        (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
      "        (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
      "      )\n",
      "    )\n",
      "    (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))\n",
      "    (fc): Linear(in_features=512, out_features=256, bias=True)\n",
      "  )\n",
      "  (gnn): GCN(\n",
      "    (gc1): GraphConvolution (256 -> 128)\n",
      "    (gc2): GraphConvolution (128 -> 4)\n",
      "  )\n",
      ")\n",
      "Metrics: None\n",
      "Device: cpu\n",
      "\n",
      "Training:\n",
      "Batch size: 64\n",
      "Optimizer: <class 'torch.optim.adam.Adam'>\n",
      "Optimizer params: {'lr': 0.001}\n",
      "Weight decay: 0.0\n",
      "Max grad norm: None\n",
      "Val dataloader: NeighborSampler(sizes=[-1, -1])\n",
      "Monitor: accuracy\n",
      "Monitor criterion: max\n",
      "Epochs: 1\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 0 / 1: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 232/232 [18:34<00:00,  4.80s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--- Train epoch-0, step-232 ---\n",
      "loss: 1.3025\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "Evaluation: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 34/34 [01:19<00:00,  2.34s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--- Eval epoch-0, step-232 ---\n",
      "accuracy: 0.4872\n",
      "f1_macro: 0.1643\n",
      "f1_micro: 0.4872\n",
      "loss: 1.2535\n",
      "New best accuracy score (0.4872) at epoch-0, step-232\n",
      "Loaded best model\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "from pyhealth.trainer import Trainer\n",
    "\n",
    "resnet_trainer = Trainer(model=model, device=\"cpu\")\n",
    "resnet_trainer.train(\n",
    "    train_dataloader=train_dataloader,\n",
    "    val_dataloader=valid_dataloader,\n",
    "    epochs=1,\n",
    "    monitor=\"accuracy\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "e1d2e71f-2a33-4187-b145-2161172821a9",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Evaluation: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 67/67 [02:41<00:00,  2.42s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'accuracy': 0.4786590097780537, 'f1_macro': 0.1618557783142647, 'f1_micro': 0.4786590097780537, 'loss': 1.256981566770753}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "print(resnet_trainer.evaluate(test_dataloader))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "26253fef",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
